"""
@author: Zhenjiao Jiang <zhenjiao.jiang@csiro.au>, October 2018.
"""

import numpy as np
import tensorflow as tf
# Geometry of the latent grid that the decoder starts from.
K_size = 25        # side length of the latent grid along the two horizontal axes
L_size = 5         # basic image size in decoder in vertical direction

# Slope applied to negative inputs by the leaky-ReLU activations.
leak_value = 0.2

#%%  outlayer size
# Base channel width of the first layer of each network.
outsize_d = 128    # discriminator
outsize_e = 64     # encoder (256 was tried as an alternative)
outsize_g = 128    # decoder / generator (64 and 256 were tried as alternatives)

#%%
def encoder(inputs, phase_train=True, reuse=False, training=False):
    """Encode a 3-D volume into a flat latent vector using four 3-D convolutions.

    Layers 1 and 2 use stride 4 along the first two spatial axes, layer 3
    uses stride 2 along all three spatial axes, and layer 4 is a stride-1
    convolution that reduces the feature maps to a single channel.  The
    layer-4 output is flattened per sample to form the latent vector.

    Args:
        inputs: 5-D tensor passed to ``tf.nn.conv3d``.  The first kernel has
            one input channel, so the last axis must have size 1, and the
            spatial sizes must be statically known because the flattened
            size is computed from the static shape.  The original shape
            notes assume (batch, 800, 800, 10, 1), which yields a
            25 x 25 x 5 = 3125-element latent vector -- TODO confirm
            against the caller.
        phase_train: Unused here; kept so the signature matches the other
            network builders.
        reuse: Forwarded to every ``tf.variable_scope`` so the weights can
            be shared between several calls.
        training: Unused here; kept for signature compatibility.

    Returns:
        Tuple ``(z_latent, mean, stddev)`` where ``z_latent`` has shape
        (batch, dim) and ``mean`` / ``stddev`` are the per-sample mean and
        standard deviation of ``z_latent`` taken over axis 1.
    """
    flt_size = 4       # kernel extent along the first two spatial axes
    flt_size_v = 2     # kernel extent along the third spatial axis

    stride_xy4 = [1, 4, 4, 1, 1]
    stride_all2 = [1, 2, 2, 2, 1]
    stride_unit = [1, 1, 1, 1, 1]

    xavier_init = tf.contrib.layers.xavier_initializer()
    zero_init = tf.zeros_initializer()

    def conv_bias(x, index, in_channels, out_channels, strides):
        # Creates (or reuses) "we<index>" / "be<index>" in the active variable
        # scope and applies convolution followed by the bias.
        kernel = tf.get_variable(
            "we%d" % index,
            shape=[flt_size, flt_size, flt_size_v, in_channels, out_channels],
            initializer=xavier_init)
        bias = tf.get_variable(
            "be%d" % index, shape=[out_channels], initializer=zero_init)
        convolved = tf.nn.conv3d(x, kernel, strides=strides, padding="SAME")
        return tf.nn.bias_add(convolved, bias)

    # Channel count entering layer i is widths[i - 1]; leaving it is widths[i].
    widths = [1, outsize_e, outsize_e // 2, outsize_e // 4]
    hidden_strides = [stride_xy4, stride_xy4, stride_all2]

    hidden = inputs
    for index, strides in enumerate(hidden_strides, start=1):
        with tf.variable_scope("encoder_layer%d" % index, reuse=reuse):
            hidden = conv_bias(hidden, index, widths[index - 1], widths[index], strides)
            hidden = lrelu(hidden, leak_value)

    with tf.variable_scope("encoder_layer4", reuse=reuse):
        # No activation on the last layer: the latent values stay unbounded.
        single_channel = conv_bias(hidden, 4, widths[3], 1, stride_unit)
        static_shape = single_channel.get_shape().as_list()
        dim = np.prod(static_shape[1:])
        z_latent = tf.reshape(single_channel, shape=[-1, dim])

    with tf.variable_scope("encoder_out", reuse=reuse):
        # tf.nn.moments returns (mean, variance); the second value must be
        # square-rooted to obtain the standard deviation that is returned.
        mean, variance = tf.nn.moments(z_latent, axes=[1])
        # The small epsilon keeps the gradient of sqrt finite at zero variance.
        stddev = tf.sqrt(variance + 1e-8)

    return z_latent, mean, stddev

#%%
def generater(z, batch_size_new, phase_train=True, reuse=False, training=False):
    """Decode a latent vector into a single-channel 3-D volume with values in (0, 1).

    The latent vector is reshaped to a (K_size, K_size, L_size) grid with one
    channel and passed through three transposed convolutions (each followed
    by batch normalisation) interleaved with stride-1 convolutions.  The
    final layer applies a sigmoid.

    Args:
        z: Latent tensor holding ``batch_size_new * K_size * K_size * L_size``
            values; it is reshaped, so any layout with that many elements is
            accepted.
        batch_size_new: Batch size used in the reshape and in every
            transposed-convolution output shape.
        phase_train: Passed to batch normalisation as ``is_training``.
        reuse: Forwarded to every ``tf.variable_scope`` so the weights can
            be shared between several calls.
        training: Passed to batch normalisation as ``trainable``.  It only
            affects the batch-norm variables; the kernels and biases are
            created with the ``tf.get_variable`` default.

    Returns:
        Tensor of shape
        ``(batch_size_new, 8 * K_size, 8 * K_size, 2 * L_size, 1)``, i.e.
        (batch, 200, 200, 10, 1) with the module defaults.

    NOTE(review): ``tf.contrib.layers.batch_norm`` registers its moving-average
    updates in ``tf.GraphKeys.UPDATE_OPS`` by default -- confirm the training
    script runs them.
    """
    # filter size sensitivity analysis
    flt_size = 5       # kernel extent along the first two spatial axes
    flt_size_v = 2     # kernel extent along the third spatial axis

    stride_xy4 = [1, 4, 4, 1, 1]
    stride_xy2 = [1, 2, 2, 1, 1]
    stride_z2 = [1, 1, 1, 2, 1]
    stride_unit = [1, 1, 1, 1, 1]

    xavier_init = tf.contrib.layers.xavier_initializer()
    zero_init = tf.zeros_initializer()

    def conv_bias(x, index, in_channels, out_channels):
        # Stride-1 convolution plus bias using "wg<index>" / "bg<index>".
        kernel = tf.get_variable(
            "wg%d" % index,
            shape=[flt_size, flt_size, flt_size_v, in_channels, out_channels],
            initializer=xavier_init)
        bias = tf.get_variable(
            "bg%d" % index, shape=[out_channels], initializer=zero_init)
        convolved = tf.nn.conv3d(x, kernel, strides=stride_unit, padding="SAME")
        return tf.nn.bias_add(convolved, bias)

    def deconv_bias_bn(x, index, in_channels, out_channels, spatial, strides):
        # Transposed convolution, bias and batch norm using "wg<index>" /
        # "bg<index>".  Transposed-conv kernels are laid out as
        # [..., out_channels, in_channels].
        kernel = tf.get_variable(
            "wg%d" % index,
            shape=[flt_size, flt_size, flt_size_v, out_channels, in_channels],
            initializer=xavier_init)
        bias = tf.get_variable(
            "bg%d" % index, shape=[out_channels], initializer=zero_init)
        output_shape = (batch_size_new,) + spatial + (out_channels,)
        upsampled = tf.nn.conv3d_transpose(
            x, kernel, output_shape, strides=strides, padding="SAME")
        upsampled = tf.nn.bias_add(upsampled, bias)
        return tf.contrib.layers.batch_norm(
            upsampled, is_training=phase_train, trainable=training)

    full_width = int(outsize_g)
    half_width = int(outsize_g / 2)
    quarter_width = int(outsize_g / 4)
    eighth_width = int(outsize_g / 8)

    # Use the module constants rather than literal 25/25/5 so the reshape
    # stays consistent with the output shapes below.
    g_0 = tf.reshape(z, [batch_size_new, K_size, K_size, L_size, 1])

    with tf.variable_scope("generator_layer1", reuse=reuse):
        # Doubles the third spatial axis: (K, K, L, 1) -> (K, K, 2L, outsize_g).
        g_1 = deconv_bias_bn(
            g_0, 1, 1, outsize_g,
            (K_size, K_size, L_size * 2), stride_z2)
        g_1 = lrelu(g_1, leak_value)

    with tf.variable_scope("generator_layer2", reuse=reuse):
        # Same spatial size, channels outsize_g -> outsize_g / 2.
        g_2 = conv_bias(g_1, 2, full_width, half_width)
        g_2 = lrelu(g_2, leak_value)

    with tf.variable_scope("generator_layer3", reuse=reuse):
        # x4 along the first two axes: -> (4K, 4K, 2L, outsize_g / 4).
        g_3 = deconv_bias_bn(
            g_2, 3, half_width, quarter_width,
            (K_size * 4, K_size * 4, L_size * 2), stride_xy4)
        g_3 = lrelu(g_3, leak_value)

    with tf.variable_scope("generator_layer4", reuse=reuse):
        # Same spatial size and channel count.
        g_4 = conv_bias(g_3, 4, quarter_width, quarter_width)
        g_4 = lrelu(g_4, leak_value)

    with tf.variable_scope("generator_layer5", reuse=reuse):
        # x2 along the first two axes: -> (8K, 8K, 2L, outsize_g / 8).
        g_5 = deconv_bias_bn(
            g_4, 5, quarter_width, eighth_width,
            (K_size * 8, K_size * 8, L_size * 2), stride_xy2)
        g_5 = lrelu(g_5, leak_value)

    with tf.variable_scope("generator_layer6", reuse=reuse):
        # Collapse to one channel and squash into (0, 1).
        g_6 = conv_bias(g_5, 6, eighth_width, 1)
        g_6 = tf.nn.sigmoid(g_6)

    return g_6



def lrelu(x, leak=0.2):
    """Leaky ReLU: element-wise maximum of ``x`` and ``leak * x``.

    For ``0 < leak < 1`` this passes positive values through unchanged and
    scales negative values by ``leak``.
    """
    scaled = leak * x
    return tf.maximum(x, scaled)

#%%
def discriminator(inputs, phase_train=True, reuse=False):
    """Score a 3-D volume with four 3-D convolutions and a sigmoid output.

    Layers 1-3 each apply convolution, bias, batch normalisation and a leaky
    ReLU; layer 4 applies convolution, bias and a sigmoid, and its output is
    flattened per sample.

    Args:
        inputs: 5-D tensor passed to ``tf.nn.conv3d``.  The first kernel has
            one input channel, so the last axis must have size 1, and the
            spatial sizes must be statically known because the flattened
            size is computed from the static shape.  The original shape
            notes assume (batch, 200, 200, 10, 1), giving a 7 x 7 x 5 output
            grid -- TODO confirm against the caller.
        phase_train: Passed to batch normalisation as ``is_training``.
        reuse: Forwarded to every ``tf.variable_scope`` so the weights can
            be shared between several calls.

    Returns:
        Tensor of shape (batch, dim) holding the sigmoid outputs of layer 4,
        where ``dim`` is the product of that layer's non-batch dimensions.
    """
    flt_size = 4       # kernel extent along the first two spatial axes
    flt_size_v = 2     # kernel extent along the third spatial axis

    halve_all = [1, 2, 2, 2, 1]
    halve_xy = [1, 2, 2, 1, 1]
    quarter_xy = [1, 4, 4, 1, 1]

    xavier_init = tf.contrib.layers.xavier_initializer()
    zero_init = tf.zeros_initializer()

    def conv_bias(x, index, in_channels, out_channels, strides):
        # Creates (or reuses) "wd<index>" / "bd<index>" in the active variable
        # scope and applies convolution followed by the bias.
        kernel = tf.get_variable(
            "wd%d" % index,
            shape=[flt_size, flt_size, flt_size_v, in_channels, out_channels],
            initializer=xavier_init)
        bias = tf.get_variable(
            "bd%d" % index, shape=[out_channels], initializer=zero_init)
        convolved = tf.nn.conv3d(x, kernel, strides=strides, padding="SAME")
        return tf.nn.bias_add(convolved, bias)

    # Channel count entering layer i is widths[i - 1]; leaving it is widths[i].
    widths = [1, outsize_d, int(outsize_d / 2), int(outsize_d / 4)]
    hidden_strides = [halve_all, halve_xy, halve_xy]

    hidden = inputs
    for index, strides in enumerate(hidden_strides, start=1):
        with tf.variable_scope("discriminator_layer%d" % index, reuse=reuse):
            hidden = conv_bias(hidden, index, widths[index - 1], widths[index], strides)
            hidden = tf.contrib.layers.batch_norm(hidden, is_training=phase_train)
            hidden = lrelu(hidden, leak_value)

    with tf.variable_scope("discriminator_layer4", reuse=reuse):
        scores = conv_bias(hidden, 4, widths[3], 1, quarter_xy)
        scores = tf.nn.sigmoid(scores)

        # Flatten everything except the batch axis.
        static_shape = scores.get_shape().as_list()
        dim = np.prod(static_shape[1:])
        flat_scores = tf.reshape(scores, shape=[-1, dim])

    return flat_scores